import os
import pprint
import time
import threading
import torch as th
from types import SimpleNamespace as SN
from utils.logging import Logger
from utils.timehelper import time_left, time_str
from os.path import dirname, abspath
import copy
import json
import shutil

from learners.multi_task import REGISTRY as le_REGISTRY
from runners.multi_task import REGISTRY as r_REGISTRY
from controllers.multi_task import REGISTRY as mac_REGISTRY
from components.episode_buffer import ReplayBuffer
from components.offline_buffer import OfflineBuffer, DataSaver, OfflineSample
from components.transforms import OneHot

import numpy as np
from copy import deepcopy
from modules.decomposers import REGISTRY as decomposer_REGISTRY


def run(_run, _config, _log):
    """Sacred entry point: prepare output/logging, run training, then hard-exit.

    The results directory is chosen by the first enabled stage flag:
    ``pretrain_vae`` -> ``vae_pretrain_save_dir``,
    ``pretrain_vqvae`` -> ``vqvae_pretrain_save_dir``,
    ``train_DT_w_glsk`` -> ``dt_w_glsk_results_save_dir``.
    TensorBoard logs go to ``<dir>/tb_logs`` (unless ``evaluate`` is set),
    ``args.save_dir`` is set to ``<dir>/models`` and the resolved config is
    dumped to ``<dir>/config.json``.

    After training, every non-main thread is given up to one second to join,
    then the process terminates with ``os._exit`` (interpreter cleanup
    handlers do not run).

    Raises:
        ValueError: if none of the three stage flags is enabled.
    """
    _config = args_sanity_check(_config, _log)

    args = SN(**_config)
    args.device = "cuda" if args.use_cuda else "cpu"

    logger = Logger(_log)

    _log.info("Experiment Parameters:")
    experiment_params = pprint.pformat(_config,
                                       indent=4,
                                       width=1)
    _log.info("\n\n" + experiment_params + "\n")

    if args.pretrain_vae:
        results_save_dir = args.vae_pretrain_save_dir
    elif args.pretrain_vqvae:
        results_save_dir = args.vqvae_pretrain_save_dir
    elif args.train_DT_w_glsk:
        results_save_dir = args.dt_w_glsk_results_save_dir
    else:
        # Without a stage flag there is no results directory to write into.
        raise ValueError("None of pretrain_vae, pretrain_vqvae or train_DT_w_glsk is enabled; "
                         "cannot choose a results directory")

    # config.json is written directly into this directory, so it must exist.
    os.makedirs(results_save_dir, exist_ok=True)

    if args.use_tensorboard and not args.evaluate:
        tb_exp_direc = os.path.join(results_save_dir, 'tb_logs')
        logger.setup_tb(tb_exp_direc)

    args.save_dir = os.path.join(results_save_dir, 'models')

    config_str = json.dumps(vars(args), indent=4)
    with open(os.path.join(results_save_dir, "config.json"), "w") as f:
        f.write(config_str)

    logger.setup_sacred(_run)

    run_sequential(args=args, logger=logger)

    print("Exiting Main")

    print("Stopping all threads")
    for t in threading.enumerate():
        if t.name != "MainThread":
            print("Thread {} is alive! Is daemon: {}".format(t.name, t.daemon))
            t.join(timeout=1)
            print("Thread joined")

    print("Exiting script")
    # Hard exit so lingering (non-daemon) environment threads cannot keep the process alive.
    os._exit(os.EX_OK)


def evaluate_sequential(main_args, logger, task2runner):
    """Evaluation-only pass over every test task.

    Each test task's runner is run in test mode
    ``max(1, test_nepisode // batch_size_run)`` times with gradients disabled.
    Its replay is saved when ``main_args.save_replay`` is set, and its
    environment is then closed. Finally an ``episode`` stat of 0 is logged and
    the recent stats are printed.
    """
    runs_per_task = max(1, main_args.test_nepisode // main_args.batch_size_run)

    with th.no_grad():
        for task_name in main_args.test_tasks:
            runner = task2runner[task_name]
            for _ in range(runs_per_task):
                runner.run(test_mode=True)
            if main_args.save_replay:
                runner.save_replay()
            runner.close_env()

    logger.log_stat("episode", 0, 0)
    logger.print_recent_stats()


def init_tasks(task_list, main_args, logger):
    """Create per-task args, runners, replay buffers and buffer layouts.

    For each task name, a deep copy of ``main_args`` is made with
    ``env_args["map_name"]`` set to that task. A runner is built from
    ``r_REGISTRY[main_args.runner]``, and every entry of the runner's
    ``get_env_info()`` is copied onto the task's args as an attribute.

    A ``ReplayBuffer`` sized for one episode of ``episode_limit + 1`` steps is
    then allocated with the scheme below. It lives on CPU when
    ``buffer_cpu_only`` is set, otherwise on the task's ``device``.

    Returns:
        A tuple of six dicts keyed by task name:
        ``(task2args, task2runner, task2buffer, task2scheme, task2groups, task2preprocess)``.
    """
    task2args, task2runner, task2buffer = {}, {}, {}
    task2scheme, task2groups, task2preprocess = {}, {}, {}

    for map_name in task_list:
        # Each task gets its own copy of the config so env-specific attributes don't leak.
        cfg = copy.deepcopy(main_args)
        cfg.env_args["map_name"] = map_name
        task2args[map_name] = cfg

        runner = r_REGISTRY[main_args.runner](args=cfg, logger=logger, task=map_name)
        task2runner[map_name] = runner

        env_info = runner.get_env_info()
        for key, value in env_info.items():
            setattr(cfg, key, value)

        scheme = dict(
            state={"vshape": env_info["state_shape"]},
            obs={"vshape": env_info["obs_shape"], "group": "agents"},
            actions={"vshape": (1,), "group": "agents", "dtype": th.long},
            avail_actions={"vshape": (env_info["n_actions"],), "group": "agents", "dtype": th.int},
            reward={"vshape": (1,)},
            terminated={"vshape": (1,), "dtype": th.uint8},
            # Extra per-agent fields used by the skill-based learners.
            skill={"vshape": (main_args.skill_dim,), "group": "agents"},
            rtg={"vshape": (1,), "group": "agents"},
            global_skill_id={"vshape": (1,), "group": "agents"},
        )
        groups = {"agents": cfg.n_agents}
        preprocess = {"actions": ("actions_onehot", [OneHot(out_dim=cfg.n_actions)])}

        buffer_device = "cpu" if cfg.buffer_cpu_only else cfg.device
        task2buffer[map_name] = ReplayBuffer(scheme, groups, 1, env_info["episode_limit"] + 1,
                                             preprocess=preprocess,
                                             device=buffer_device)

        task2scheme[map_name] = scheme
        task2groups[map_name] = groups
        task2preprocess[map_name] = preprocess

    return task2args, task2runner, task2buffer, task2scheme, task2groups, task2preprocess
                

def run_sequential(args, logger):
    """Build tasks, MAC and learner, then run the configured training stage(s).

    Stages are checked in this order, and several may run in one call:

    * Checkpoint loading when ``checkpoint_path`` is non-empty. The numeric
      sub-directory closest to ``load_step`` is loaded, or the largest one when
      ``load_step == 0``. With ``evaluate``/``save_replay`` only evaluation runs,
      then the function returns.
    * ``pretrain_vae``:
        - When ``pretrain`` is set, pre-train on ``pretrain_tasks``, save the
          models under ``<save_dir>/<pretrain_steps>`` and return.
        - Otherwise, when ``main_args`` has a ``pretrain`` attribute, load
          models from ``<pretrain_save_dir>/<pretrain_steps>``.
    * ``pretrain_vqvae``: train on ``train_tasks`` and save the models under
      ``<save_dir>/<pretrain_steps>``.
    * ``train_DT_w_glsk``: train on the global-skill-labelled datasets and return.

    When the function has not returned by then, models are saved under
    ``<save_dir>/<t_max>`` (if ``save_model``) and the test environments are
    closed.
    """
    args.n_tasks = len(args.train_tasks)
    main_args = copy.deepcopy(args)

    # De-duplicate while keeping first-seen order; list(set(...)) ordered tasks by
    # string hash, which differs between interpreter runs and hurts reproducibility.
    task_names = args.train_tasks + args.test_tasks
    if getattr(main_args, "pretrain", False):
        task_names = task_names + args.pretrain_tasks
    all_tasks = list(dict.fromkeys(task_names))

    (task2args, task2runner, task2buffer,
     task2scheme, task2groups, task2preprocess) = init_tasks(all_tasks, main_args, logger)
    task2buffer_scheme = {task: task2buffer[task].scheme for task in all_tasks}

    mac = mac_REGISTRY[main_args.mac](train_tasks=all_tasks, task2scheme=task2buffer_scheme, task2args=task2args,
                                      main_args=main_args)
    for task in main_args.test_tasks:
        task2runner[task].setup(scheme=task2scheme[task], groups=task2groups[task], preprocess=task2preprocess[task],
                                mac=mac)
    learner = le_REGISTRY[main_args.learner](mac, logger, main_args)
    if main_args.use_cuda:
        learner.cuda()

    if main_args.checkpoint_path != "":
        if not os.path.isdir(main_args.checkpoint_path):
            logger.console_logger.info("Checkpoint directiory {} doesn't exist".format(main_args.checkpoint_path))
            return

        # Checkpoints are sub-directories named by the timestep they were saved at.
        timesteps = [
            int(name) for name in os.listdir(main_args.checkpoint_path)
            if name.isdigit() and os.path.isdir(os.path.join(main_args.checkpoint_path, name))
        ]
        if not timesteps:
            # max()/min() below would raise on an empty list.
            logger.console_logger.info("No timestep checkpoints found in {}".format(main_args.checkpoint_path))
            return

        if main_args.load_step == 0:
            # choose the max timestep
            timestep_to_load = max(timesteps)
        else:
            timestep_to_load = min(timesteps, key=lambda x: abs(x - main_args.load_step))

        model_path = os.path.join(main_args.checkpoint_path, str(timestep_to_load))

        logger.console_logger.info("Loading model from {}".format(model_path))

        learner.load_models(model_path)

        if main_args.evaluate or main_args.save_replay:
            evaluate_sequential(main_args, logger, task2runner)
            return

    if main_args.pretrain_vae:
        if getattr(main_args, "pretrain", False):
            task2offlinedata = {}
            for task in main_args.pretrain_tasks:
                task2offlinedata[task] = OfflineBuffer(task, main_args.pretrain_tasks_data_quality[task],
                                                       data_folder=main_args.offline_data_name,
                                                       dataset_folder=main_args.offline_data_folder,
                                                       offline_data_size=args.offline_data_size,
                                                       random_sample=args.offline_data_shuffle)

            test_task2offlinedata = None
            if hasattr(learner, 'test_pretrain') and hasattr(main_args, 'test_tasks_data_quality'):
                test_task2offlinedata = {}
                for task in main_args.test_tasks_data_quality.keys():
                    # NOTE(review): unlike the training buffers above, no `dataset_folder`
                    # is passed here, so OfflineBuffer's default is used — confirm intended.
                    test_task2offlinedata[task] = OfflineBuffer(task, main_args.test_tasks_data_quality[task],
                                                                data_folder=main_args.offline_data_name,
                                                                offline_data_size=args.offline_data_size,
                                                                random_sample=args.offline_data_shuffle)

            logger.console_logger.info(
                "Beginning pre-training with {} timesteps for each task".format(main_args.pretrain_steps))
            # The VAE pre-training stage uses no decomposers; t_start is passed by keyword
            # (it was previously landing in the task2decomposer slot).
            train_sequential(main_args.pretrain_tasks, main_args, logger, learner, task2args, task2runner,
                             task2offlinedata, None, t_start=0,
                             pretrain=True, test_task2offlinedata=test_task2offlinedata)

            logger.console_logger.info("Finished pretraining")

            save_path = os.path.join(main_args.save_dir, str(main_args.pretrain_steps))
            os.makedirs(save_path, exist_ok=True)
            logger.console_logger.info("Saving models to {}".format(save_path))
            learner.save_models(save_path)

            logger.console_logger.info("Finished Pretrain VAE Training")
            return

        elif hasattr(main_args, "pretrain"):
            load_path = os.path.join(main_args.pretrain_save_dir, str(main_args.pretrain_steps))
            learner.load_models(load_path)
            logger.console_logger.info("Load pretrained models from {}".format(load_path))

    if main_args.pretrain_vqvae:
        task2offlinedata4dt = {}
        task2decomposer = {}
        vqvae_dataset_folder = (main_args.offline_data_folder + "_debug"
                                if main_args.debug else main_args.offline_data_folder)
        for task in main_args.train_tasks:
            task2offlinedata4dt[task] = OfflineBuffer(task, main_args.train_tasks_data_quality[task]+"/skill_dim_"+str(main_args.skill_dim)+"/"+ main_args.pretrain_skill_time,
                                                      dataset_folder=vqvae_dataset_folder,
                                                      data_folder=main_args.offline_data_name,
                                                      offline_data_size=args.offline_data_size,
                                                      random_sample=args.offline_data_shuffle)
            task_args = task2args[task]
            task2decomposer[task] = decomposer_REGISTRY[task_args.env](task_args)

        logger.console_logger.info(
                "Beginning pre-training vqvae with {} timesteps for each task".format(main_args.pretrain_steps))
        train_sequential(main_args.train_tasks, main_args, logger, learner, task2args, task2runner,
                         task2offlinedata4dt, task2decomposer)

        save_path = os.path.join(main_args.save_dir, str(main_args.pretrain_steps))
        os.makedirs(save_path, exist_ok=True)
        learner.save_models(save_path)

    if main_args.train_DT_w_glsk:
        task2offlinedata4dt = {}
        task2decomposer = {}
        gl_skill_dataset_folder = (main_args.offline_data_folder4_gl_skill + "_debug"
                                   if main_args.debug else main_args.offline_data_folder4_gl_skill)

        for task in main_args.train_tasks:
            task2offlinedata4dt[task] = OfflineBuffer(task, main_args.train_tasks_data_quality[task]+"/gl_skill_dim_"+str(main_args.vqvae_K)+"/"+ main_args.train_global_skill_time,
                                                      dataset_folder=gl_skill_dataset_folder,
                                                      data_folder=str(main_args.offline_data_name),
                                                      offline_data_size=args.offline_data_size,
                                                      random_sample=args.offline_data_shuffle)

            task_args = task2args[task]
            task2decomposer[task] = decomposer_REGISTRY[task_args.env](task_args)

        logger.console_logger.info(
            "[DT Training] Beginning multi-task offline training with {} timesteps for each task".format(main_args.t_max))
        train_sequential(main_args.train_tasks, main_args, logger, learner, task2args, task2runner,
                         task2offlinedata4dt, task2decomposer)

        return

    if main_args.save_model:
        save_path = os.path.join(main_args.save_dir, str(main_args.t_max))
        os.makedirs(save_path, exist_ok=True)
        logger.console_logger.info("Saving final models to {}".format(save_path))
        learner.save_models(save_path)

    for task in args.test_tasks:
        task2runner[task].close_env()
    logger.console_logger.info("Finished Training")


def train_sequential(train_tasks, main_args, logger, learner, task2args, task2runner, task2offlinedata, task2decomposer,t_start=0,
                     pretrain=False, test_task2offlinedata=None):
    """Run the offline multi-task training loop for the active training stage.

    One step samples ``main_args.batch_size`` episodes from one task's offline
    buffer and hands the batch to the learner. Each pass over the shuffled task
    list therefore advances ``t_env`` by ``len(train_tasks)``. The learner
    method depends on the stage:

    * ``pretrain=True``             -> ``learner.pretrain``
    * ``main_args.pretrain_vqvae``  -> ``learner.pretrain_vqvae``
    * ``main_args.train_DT_w_glsk`` -> ``learner.train_by_skill``

    A truthy return value from the learner stops training early.

    Args:
        train_tasks: task names to train on. A shuffled private copy is used,
            so the caller's list is not reordered.
        main_args: global config namespace.
        logger: project logger (``console_logger``, ``log_stat``, ...).
        learner: learner providing the stage's training method.
        task2args: per-task config; its ``device`` decides where batches go.
        task2runner: per-task runners used for test rollouts.
        task2offlinedata: per-task offline buffers to sample from.
        task2decomposer: per-task decomposers. Only the ``train_DT_w_glsk``
            stage uses them, so this may be ``None`` for the other stages.
        t_start: initial value of the step counter ``t_env``.
        pretrain: run pre-training. The step budget is then
            ``main_args.pretrain_steps`` instead of ``main_args.t_max``.
        test_task2offlinedata: optional per-task buffers that are passed to
            ``learner.test_pretrain`` before training, when
            ``main_args.test_pretrain`` and ``pretrain`` are set.

    Raises:
        ValueError: if ``train_tasks`` is empty, if no training stage is
            enabled, or if the learner lacks the method the stage requires.
    """
    # With no tasks t_env would never advance and the loop would spin forever.
    if not train_tasks:
        raise ValueError("train_sequential needs at least one training task")
    if not (pretrain or main_args.pretrain_vqvae or main_args.train_DT_w_glsk):
        raise ValueError("train_sequential called without an active training stage "
                         "(pretrain, pretrain_vqvae or train_DT_w_glsk)")

    # np.random.shuffle works in place; shuffle a copy so callers passing
    # main_args.train_tasks / main_args.pretrain_tasks keep their order.
    train_tasks = list(train_tasks)
    n_train_tasks = len(train_tasks)

    t_env = t_start
    episode = 0
    t_max = main_args.pretrain_steps if pretrain else main_args.t_max
    model_save_time = 0
    last_test_T = 0
    start_time = time.time()
    last_time = start_time
    test_time_total = 0

    batch_size_train = main_args.batch_size
    batch_size_run = main_args.batch_size_run
    n_test_runs = max(1, main_args.test_nepisode // batch_size_run)

    # Early-stop flag returned by the learner; None until a learner step has run.
    terminated = None

    if main_args.test_pretrain:
        test_start_time = time.time()
        with th.no_grad():
            for task in main_args.test_tasks:
                task2runner[task].t_env = t_env
                for _ in range(n_test_runs):
                    task2runner[task].run(test_mode=True, pretrain=pretrain)

            if pretrain and test_task2offlinedata is not None:
                for task, data_buffer in test_task2offlinedata.items():
                    # Evaluation uses a batch three times the training batch size.
                    episode_sample = data_buffer.sample(batch_size_train * 3)

                    if episode_sample.device != task2args[task].device:
                        episode_sample.to(task2args[task].device)

                    if not hasattr(learner, 'test_pretrain'):
                        raise ValueError("Do test_pretrain with a learner that does not have a `test_pretrain` method!")
                    learner.test_pretrain(episode_sample, t_env, episode, task)

        test_time_total += time.time() - test_start_time

    while t_env < t_max:
        np.random.shuffle(train_tasks)
        for task in train_tasks:
            episode_sample = task2offlinedata[task].sample(batch_size_train)

            if episode_sample.device != task2args[task].device:
                episode_sample.to(task2args[task].device)

            if pretrain:
                if not hasattr(learner, 'pretrain'):
                    raise ValueError("Do pretraining with a learner that does not have a `pretrain` method!")
                terminated = learner.pretrain(episode_sample, t_env, episode, task)
            elif main_args.pretrain_vqvae:
                terminated = learner.pretrain_vqvae(episode_sample, t_env, episode, task, logger)
            elif main_args.train_DT_w_glsk:
                terminated = learner.train_by_skill(episode_sample, t_env, episode, task, logger, task2decomposer)

            if terminated:
                break

            t_env += 1
            episode += batch_size_run

        if main_args.pretrain_vae:
            learner.update(pretrain=pretrain)

        if terminated:
            logger.console_logger.info(f"Terminate training by the learner at t_env = {t_env}. Finish training.")
            break

        if pretrain and t_env >= t_max:
            save_skill(learner, train_tasks, task2offlinedata, main_args, logger, t_max)

        # Export the global-skill-labelled dataset every `save_dataset_interval`
        # passes over the task list (t_env grows by n_train_tasks per pass).
        if main_args.pretrain_vqvae and (t_env / n_train_tasks) % main_args.save_dataset_interval == 0:
            save_tactic_from_vqvae(learner, train_tasks, task2offlinedata, main_args, logger, t_env)

        run_tests = (not main_args.pretrain_vqvae and not pretrain and
                     ((t_env - last_test_T) / main_args.test_interval >= 1 or t_env >= t_max))
        if run_tests:
            test_start_time = time.time()

            with th.no_grad():
                for task in main_args.test_tasks:
                    task2runner[task].t_env = t_env
                    for _ in range(n_test_runs):
                        task2runner[task].run(test_mode=True, pretrain=False)

            test_time_total += time.time() - test_start_time

            logger.console_logger.info("Step: {} / {}".format(t_env, t_max))
            logger.console_logger.info("Estimated time left: {}. Time passed: {}. Test time cost: {}".format(
                time_left(last_time, last_test_T, t_env, t_max), time_str(time.time() - start_time),
                time_str(test_time_total)
            ))
            last_time = time.time()
            last_test_T = t_env

            logger.log_stat("episode", 0, 0)
            logger.print_recent_stats()

        # model_save_time == 0 also triggers a save after the first pass.
        if main_args.save_model and (t_env - model_save_time >= main_args.save_model_interval or model_save_time == 0):
            save_path = os.path.join(main_args.save_dir, str(t_env))
            logger.console_logger.info("Timesteps ={}, Saving Model at {}".format(t_env, save_path))
            os.makedirs(save_path, exist_ok=True)
            logger.console_logger.info("Saving models to {}".format(save_path))
            learner.save_models(save_path)
            model_save_time = t_env

            if main_args.pretrain_vqvae:
                learner.mac.vqvae_model.save_embedding_to_npy(save_path)


def args_sanity_check(config, _log):
    """Normalise the raw config dict in place and return it.

    * ``use_cuda`` is forced to ``False``, with a warning, when CUDA was
      requested but no CUDA device is available.
    * ``test_nepisode`` is rounded down to a multiple of ``batch_size_run``.
      It is never allowed below one full ``batch_size_run``.
    """
    wants_cuda = config["use_cuda"]
    if wants_cuda and not th.cuda.is_available():
        config["use_cuda"] = False
        _log.warning("CUDA flag use_cuda was switched OFF automatically because no CUDA devices are available!")

    per_run = config["batch_size_run"]
    requested = config["test_nepisode"]
    config["test_nepisode"] = per_run if requested < per_run else (requested // per_run) * per_run

    return config


def save_skill(learner, train_tasks, task2offlinedata, main_args, logger, t_max):
    """Annotate each task's offline dataset with learned skills and save it.

    For every task the full offline buffer (``task2offlinedata[task].buffer.data``)
    is deep-copied, converted to tensors and replayed through
    ``learner.mac.forward_skill`` one timestep at a time under ``no_grad``.
    The stacked outputs are stored as ``data['skill']`` and their arg-max over
    the last dimension as ``data['skill_max']``. The result is written with a
    ``DataSaver`` to
    ``<offline_data_folder_4_save_vae>[_debug]/<task>/<offline_data_quality>/skill_dim_<skill_dim>/<unique_token>/<t_max>``.
    """
    # Use the run's configured device ("cuda"/"cpu" as set in `run`) instead of
    # hard-coding CUDA, so CPU-only runs don't build batches on a missing device.
    device = getattr(main_args, "device", "cuda")
    base_folder = main_args.offline_data_folder_4_save_vae
    if main_args.debug:
        base_folder = base_folder + "_debug"

    for task in train_tasks:
        data = deepcopy(task2offlinedata[task].buffer.data)
        # Longest filled episode: sum `filled` over dim 1, then take the max over dim 0.
        filled = th.from_numpy(data['filled'])
        max_t_filled = th.sum(filled, 1).max(0)[0].item()

        data = {k: th.from_numpy(v) for k, v in data.items()}
        batch_size = data['filled'].shape[0]
        batch = OfflineSample(data, batch_size, max_t_filled, device=device)
        actions = batch["actions"][:, :]

        mac_out = []
        with th.no_grad():
            learner.mac.init_hidden(batch.batch_size, task)
            for t in range(batch.max_seq_length):
                agent_outs = learner.mac.forward_skill(batch, t=t, task=task, actions=actions[:, t, :])
                mac_out.append(agent_outs)
            mac_out = th.stack(mac_out, dim=1)  # Concat over time
            max_skill = mac_out.max(dim=-1)[1]

        data['skill_max'] = max_skill
        data['skill'] = mac_out
        batch_w_sk = OfflineSample(data, batch_size, max_t_filled, device=device)

        save_dir = os.path.join(base_folder, task, main_args.offline_data_quality,
                                "skill_dim_" + str(main_args.skill_dim), main_args.unique_token,
                                "{}".format(t_max))
        offline_skill_saver = DataSaver(save_dir, max_size=1)

        offline_skill_saver.append({
            k: batch_w_sk[k].clone().cpu() for k in batch_w_sk.data.keys()})

        savepath = offline_skill_saver.close()
        logger.console_logger.info("Save offline buffer with skill to {}".format(savepath))
        

def save_tactic_from_vqvae(learner, train_tasks, task2offlinedata, main_args, logger, t_env):
    """Label each task's offline dataset with VQ-VAE global-skill ids and save it.

    For every task the full offline buffer is deep-copied, converted to tensors
    and passed through ``learner.mac.vqvae_model.forward_indices`` one timestep
    at a time under ``no_grad``. The resulting code indices are stored under
    ``data['global_skill_id']``.

    Both the labelled dataset (via ``DataSaver``) and the VQ-VAE embedding (via
    ``save_embedding_to_npy``) are written to
    ``<offline_data_folder4_gl_skill>[_debug]/<task>/<offline_data_quality>/gl_skill_dim_<vqvae_K>/<unique_token>/<t_env>``.
    """
    # Use the run's configured device instead of hard-coding CUDA.
    device = getattr(main_args, "device", "cuda")
    base_folder = main_args.offline_data_folder4_gl_skill
    if main_args.debug:
        base_folder = base_folder + "_debug"

    for task in train_tasks:
        data = deepcopy(task2offlinedata[task].buffer.data)
        batch_size, agent_num = data['obs'].shape[0], data['obs'].shape[2]

        # Longest filled episode: sum `filled` over dim 1, then take the max over dim 0.
        filled = th.from_numpy(data['filled'])
        max_t_filled = th.sum(filled, 1).max(0)[0].item()

        data = {k: th.from_numpy(v) for k, v in data.items()}
        batch = OfflineSample(data, batch_size, max_t_filled, device=device)

        global_skill_ids = []
        # Pure inference: don't build an autograd graph over the whole dataset or
        # leave one attached to mac.hidden_states_enc.
        with th.no_grad():
            learner.mac.init_hidden(batch_size, task)
            for t in range(batch.max_seq_length):
                indices, hidden_states_enc = learner.mac.vqvae_model.forward_indices(
                    batch, t, task, learner.mac.hidden_states_enc)
                learner.mac.hidden_states_enc = hidden_states_enc
                global_skill_ids.append(indices)

        # NOTE(review): the reshape assumes the per-step indices are laid out
        # batch-major / agent-minor — confirm against forward_indices.
        global_skill_ids = th.stack(global_skill_ids, dim=1).reshape(batch_size, agent_num, batch.max_seq_length, -1)\
                                                            .permute(0, 2, 1, 3)  # (batch_size, batch.max_seq_length, agent_num, 1)
        data['global_skill_id'] = global_skill_ids

        batch_w_global_sk = OfflineSample(data, batch_size, max_t_filled, device=device)

        save_dir = os.path.join(base_folder, task, main_args.offline_data_quality,
                                "gl_skill_dim_" + str(main_args.vqvae_K), main_args.unique_token,
                                "{}".format(t_env))
        # The saver is created before the embedding is written, as in the original order.
        offline_skill_saver = DataSaver(save_dir, max_size=1)

        learner.mac.vqvae_model.save_embedding_to_npy(save_dir)

        offline_skill_saver.append({
            k: batch_w_global_sk[k].clone().cpu() for k in batch_w_global_sk.data.keys()})

        savepath = offline_skill_saver.close()
        logger.console_logger.info("Save offline buffer with skill to {}".format(savepath))


        
